#include <gtest/gtest.h>
#include <iostream>
#define protected public
#define private public
#include "framework/graph/core/op/op_desc.h"
#include "graph/op/const_defs.h"
#include "framework/tensor/tensor.h"
#include "graph/attr_value.h"
#include "framework/graph/core/node/node.h"
#include "framework/graph/core/node/node_compatibler.h"
#include "framework/graph/core/cgraph/compute_graph.h"
#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/debug/ge_op_types.h"
#include "framework/graph/utils/graph_utils.h"
#include "framework/graph/utils/tensor_utils.h"
#include "framework/graph/utils/attr_utils.h"
#include "framework/graph/core/cgraph/graph_builder.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#include "framework/graph/core/cgraph/graph_finder.h"
#include "framework/graph/core/edge/edge.h"
#include "graph/op/all_ops.h"
#include "graph/core/op/constholder_op_desc.h"
#undef protected
#undef private
using namespace std;
using namespace ge;

class ge_test_opdesc_utils : public testing::Test {
protected:
    void SetUp()
    {
    }

    void TearDown()
    {
    }
};

// Builds an OpDesc with inputs "x"/"w", output "y" and an INT attribute, wraps it
// in an Operator via OpDescUtils::CreateOperatorFromOpDesc, and checks that:
//  - the attribute can be read back from the desc,
//  - input descriptors are retrievable by name and by index with the shapes added,
//  - GetOpDescFromOperator returns the very same OpDesc instance and name.
TEST_F(ge_test_opdesc_utils, CreateOperatorFromDesc)
{
    OpDescPtr descPtr = std::make_shared<OpDesc>("name1", "type1");
    EXPECT_EQ(descPtr->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr->AddInputDesc("w", TensorDesc(Shape({1, 1, 1, 1}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);
    AttrValue test_attr = AttrValue::CreateFrom(static_cast<AttrValue::INT>(1));
    descPtr->SetAttr("test_attr", std::move(test_attr));

    Operator oprt = OpDescUtils::CreateOperatorFromOpDesc(descPtr);

    // Read the attribute back so the SetAttr above is actually verified; a
    // default-constructed AttrValue would trivially compare against 0 and could
    // never detect a failure.
    int64_t out = 0;
    EXPECT_TRUE(AttrUtils::GetInt(descPtr, "test_attr", out));
    EXPECT_EQ(out, 1);

    // Lookup by name.
    TensorDesc inputDesc1 = descPtr->GetInputDesc("x");
    EXPECT_TRUE(inputDesc1.GetShape().GetDimNum() == 4);
    EXPECT_TRUE(inputDesc1.GetShape().GetDim(0) == 1);
    EXPECT_TRUE(inputDesc1.GetShape().GetDim(1) == 16);
    EXPECT_TRUE(inputDesc1.GetShape().GetDim(2) == 16);
    EXPECT_TRUE(inputDesc1.GetShape().GetDim(3) == 16);

    // Lookup by index: "w" was the second input added.
    TensorDesc inputDesc2 = descPtr->GetInputDesc(1);
    EXPECT_TRUE(inputDesc2.GetShape().GetDimNum() == 4);
    EXPECT_TRUE(inputDesc2.GetShape().GetDim(0) == 1);
    EXPECT_TRUE(inputDesc2.GetShape().GetDim(1) == 1);
    EXPECT_TRUE(inputDesc2.GetShape().GetDim(2) == 1);
    EXPECT_TRUE(inputDesc2.GetShape().GetDim(3) == 1);

    // The Operator must hand back the exact OpDesc it was created from.
    OpDescPtr outPtr = OpDescUtils::GetOpDescFromOperator(oprt);
    EXPECT_EQ(outPtr, descPtr);

    string name1 = outPtr->GetName();
    string name2 = oprt.GetName();
    EXPECT_EQ(name1, name2);
}

// Checks that OpDesc::ClearInputDesc succeeds for an existing input index (0).
// descPtr is registered in a graph only once; the clear itself is exercised on
// the standalone descPtr2.
TEST_F(ge_test_opdesc_utils, ClearInputDesc)
{
    OpDescPtr descPtr = std::make_shared<OpDesc>("name1", "type1");
    EXPECT_EQ(descPtr->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr->AddInputDesc("w", TensorDesc(Shape({1, 1, 1, 1}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    OpDescPtr descPtr2 = std::make_shared<OpDesc>("name2", "type2");
    EXPECT_EQ(descPtr2->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr2->AddInputDesc("w", TensorDesc(Shape({1, 1, 1, 1}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr2->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    ComputeGraphPtr graphPtr = ge::ComputeGraph::Make("name");
    // Add descPtr to the graph exactly once; inserting the same OpDesc twice
    // would create two nodes sharing one descriptor.
    Node* n1 = graphPtr->ROLE(GraphModifier).AddNode(descPtr);
    EXPECT_NE(nullptr, n1);

    // NOTE(review): a node-level check, OpDescUtils::ClearInputDesc(*n1), is not
    // exercised here — confirm whether that overload should be covered.
    EXPECT_TRUE(descPtr2->ClearInputDesc(0));
}

// Builds producer --(y1)--> consumerA and producer --(y2)--> consumerB, then
// expects OpDescUtils::SetOutNodeWeightDef to:
//  - succeed while the producer has no outgoing edges,
//  - fail once both outputs are wired to consumers,
//  - fail when called with an empty weight list.
TEST_F(ge_test_opdesc_utils, SetOutNodeWeightDef)
{
    // Producer: one input and two outputs (y1, y2).
    OpDescPtr producerDesc = std::make_shared<OpDesc>("name1", "type1");
    EXPECT_EQ(producerDesc->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(producerDesc->AddOutputDesc("y1", TensorDesc(Shape({1, 1, 1, 1}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(producerDesc->AddOutputDesc("y2", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    // Two single-input consumers, one per producer output.
    OpDescPtr consumerDescA = std::make_shared<OpDesc>("name2", "type2");
    EXPECT_EQ(consumerDescA->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(consumerDescA->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    OpDescPtr consumerDescB = std::make_shared<OpDesc>("name3", "type3");
    EXPECT_EQ(consumerDescB->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(consumerDescB->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    ComputeGraphPtr computeGraph = ge::ComputeGraph::Make("name");
    Node* producer = computeGraph->ROLE(GraphModifier).AddNode(producerDesc);
    Node* consumerA = computeGraph->ROLE(GraphModifier).AddNode(consumerDescA);
    Node* consumerB = computeGraph->ROLE(GraphModifier).AddNode(consumerDescB);

    // A single one-element float weight.
    float weightData[1] = {1.0};
    TensorDesc weightDesc(Shape({1}));
    vector<ge::TensorPtr> weights {
        std::make_shared<Tensor>(weightDesc, reinterpret_cast<const uint8_t*>(weightData), 1 * sizeof(float))
    };

    EXPECT_EQ(GRAPH_SUCCESS, OpDescUtils::SetWeights(*producer, weights));
    EXPECT_EQ(ge::GRAPH_SUCCESS, OpDescUtils::SetOutNodeWeightDef(*producer, weights));

    // Wire y1 -> consumerA and y2 -> consumerB; AddEdge results are intentionally ignored.
    (void)computeGraph->ROLE(GraphModifier).AddEdge({*producer, 0}, {*consumerA, 0});
    (void)computeGraph->ROLE(GraphModifier).AddEdge({*producer, 1}, {*consumerB, 0});
    EXPECT_NE(ge::GRAPH_SUCCESS, OpDescUtils::SetOutNodeWeightDef(*producer, weights));

    weights.clear();
    EXPECT_NE(ge::GRAPH_SUCCESS, OpDescUtils::SetOutNodeWeightDef(*producer, weights));
}

// Exercises OpDescUtils::SetWeights (OpDescPtr and OpDesc& overloads),
// MutableWeights(OpDescPtr) and GetWeights(Node&):
//  - a null tensor is rejected by both SetWeights overloads,
//  - a non-null tensor is accepted,
//  - MutableWeights returns non-null for a desc with weights and null for a null desc,
//  - GetWeights on a node without incoming const edges yields no tensors.
TEST_F(ge_test_opdesc_utils, MutableWeights)
{
    OpDescPtr descPtr = std::make_shared<OpDesc>("name1", "type1");
    EXPECT_EQ(descPtr->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr->AddInputDesc("w", TensorDesc(Shape({1, 1, 1, 1}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    OpDescPtr descPtr2 = std::make_shared<OpDesc>("name2", "type2");
    EXPECT_EQ(descPtr2->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr2->AddInputDesc("w", TensorDesc(Shape({1, 1, 1, 1}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr2->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    ComputeGraphPtr graphPtr = ge::ComputeGraph::Make("name");
    // descPtr is added once; n1 is dereferenced below, so stop if it is missing.
    Node* n1 = graphPtr->ROLE(GraphModifier).AddNode(descPtr);
    ASSERT_NE(nullptr, n1);

    float f[1] = {1.0};
    TensorDesc tensorDesc(Shape({1}));
    TensorPtr tensor = std::make_shared<Tensor>(tensorDesc, reinterpret_cast<const uint8_t*>(f), 1 * sizeof(float));

    OpDescPtr nullOpDesc = nullptr;

    EXPECT_EQ(GRAPH_FAILED, OpDescUtils::SetWeights(descPtr, nullptr));
    EXPECT_EQ(GRAPH_SUCCESS, OpDescUtils::SetWeights(descPtr, tensor));
    EXPECT_EQ(GRAPH_SUCCESS, OpDescUtils::SetWeights(*descPtr2, tensor));
    EXPECT_EQ(GRAPH_FAILED, OpDescUtils::SetWeights(*descPtr2, nullptr));

    EXPECT_NE(nullptr, OpDescUtils::MutableWeights(descPtr));

    EXPECT_EQ(nullptr, OpDescUtils::MutableWeights(nullOpDesc));

    auto tensorVec = OpDescUtils::GetWeights(*n1);
    // Unsigned literal avoids a signed/unsigned comparison inside EXPECT_EQ.
    EXPECT_EQ(0U, tensorVec.size());
}
// Covers a range of OpDescUtils entry points:
//  - Operator <-> OpDesc round-trip;
//  - OpDesc::ClearInputDesc;
//  - sparse-algorithm param accessors;
//  - SetWeights on a graph-attached desc;
//  - const/weight queries on its node;
//  - CreateOperatorFromOpDesc(nullptr).
// Calls without a verifiable result are smoke calls only; they check that the
// call does not crash.
TEST_F(ge_test_opdesc_utils, Operator)
{
    OpDescPtr descPtr = std::make_shared<OpDesc>("name1", "type1");
    EXPECT_EQ(descPtr->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    auto op = OpDescUtils::CreateOperatorFromOpDesc(descPtr);
    EXPECT_EQ(descPtr->AddInputDesc("w", TensorDesc(Shape({1, 1, 1, 1}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    // The Operator should hand back the very OpDesc it was created from.
    OpDescPtr fromOp = OpDescUtils::GetOpDescFromOperator(op);
    EXPECT_TRUE(fromOp == descPtr);
    EXPECT_TRUE(descPtr->ClearInputDesc(0));

    OpDescPtr descPtr2 = std::make_shared<OpDesc>("name2", "type2");
    EXPECT_EQ(descPtr2->AddInputDesc("x", TensorDesc(Shape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr2->AddInputDesc("w", TensorDesc(Shape({1, 1, 1, 1}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_EQ(descPtr2->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);
    EXPECT_FALSE(OpDescUtils::HasSparseAlgorithmParams(descPtr2));
    SparseAlgorithmParams coordGrid;
    (void)OpDescUtils::SetSparseAlgorithmParams(descPtr2, coordGrid);
    (void)OpDescUtils::GetSparseAlgorithmParams(descPtr2, coordGrid);

    TensorDesc tensorDesc(Shape({1}));
    float f[1] = {1.0};
    TensorPtr tensor = std::make_shared<Tensor>(tensorDesc, reinterpret_cast<const uint8_t*>(f), 1 * sizeof(float));
    ComputeGraphPtr graphPtr = ge::ComputeGraph::Make("name");
    // node is dereferenced below, so stop if AddNode failed.
    Node* node = graphPtr->ROLE(GraphModifier).AddNode(descPtr2);
    ASSERT_NE(nullptr, node);
    EXPECT_EQ(GRAPH_SUCCESS, OpDescUtils::SetWeights(descPtr2, tensor));
    (void)OpDescUtils::GetWeightsWithNoConst(*node);
    (void)OpDescUtils::GetConstInputNames(*node);
    (void)OpDescUtils::CreateOperatorFromOpDesc(nullptr);
    (void)OpDescUtils::MutableWeights(*node);
}
// Checks the following:
//  - NodeCompatibler::TransTensorToConstInput succeeds on a node with one ND input.
//  - An INT attribute set through AttrUtils::SetInt is visible via HasAttr and
//    readable as uint32_t.
//  - OpDescUtils::ClearWeights succeeds, including when called a second time.
TEST_F(ge_test_opdesc_utils, Operator_two)
{
    std::shared_ptr<ComputeGraph> graphPtr = ge::ComputeGraph::Make("ut_test_graph");
    OpDescPtr descPtr = std::make_shared<OpDesc>("name1", "type1");
    ge::Shape tensorShape;
    descPtr->AddInputDesc(ge::TensorDesc(tensorShape, ge::FORMAT_ND));
    // n1 is dereferenced below, so stop if AddNode failed.
    Node* n1 = graphPtr->ROLE(GraphModifier).AddNode(descPtr);
    ASSERT_NE(nullptr, n1);

    TensorDesc tensorDesc(Shape({1}));
    float f[1] = {1.0};
    TensorPtr tensor_ptr = std::make_shared<Tensor>(tensorDesc, reinterpret_cast<const uint8_t*>(f), 1 * sizeof(float));
    auto ret = n1->ROLE(NodeCompatibler).TransTensorToConstInput(tensor_ptr, 0);
    EXPECT_EQ(ret, GRAPH_SUCCESS);

    const int32_t batchSize = 3;
    const int32_t classesNum = 4;

    TensorDesc tensorDesc1(Shape({batchSize, classesNum}));
    tensorDesc1.SetDataType(DT_FLOAT);

    OpDescPtr opDescPtr = std::make_shared<OpDesc>("Multinomial", "Multinomial");
    EXPECT_TRUE(ge::AttrUtils::SetInt(opDescPtr, "min_size_num", 2));
    opDescPtr->AddInputDesc(tensorDesc1);
    TensorDesc outDesc;
    opDescPtr->AddOutputDesc(outDesc);
    Node* nodePtr = graphPtr->ROLE(GraphModifier).AddNode(opDescPtr);
    ASSERT_NE(nullptr, nodePtr);

    EXPECT_TRUE(opDescPtr->HasAttr("min_size_num"));
    // Initialised so a failed GetInt cannot leave the check below reading an
    // indeterminate value.
    uint32_t value = 0;
    EXPECT_TRUE(AttrUtils::GetInt(opDescPtr, "min_size_num", value));
    EXPECT_EQ(value, 2U);

    const vector<ge::TensorPtr> weights;
    // Result not asserted: only the ClearWeights outcomes below are checked.
    (void)OpDescUtils::SetWeights(*nodePtr, weights);
    ret = OpDescUtils::ClearWeights(*nodePtr);
    EXPECT_EQ(ret, GRAPH_SUCCESS);
    // Clearing an already-cleared node must also succeed.
    ret = OpDescUtils::ClearWeights(*nodePtr);
    EXPECT_EQ(ret, GRAPH_SUCCESS);
}
// Covers AttrUtils list-int and scalar-int accessors for int64_t, int32_t and
// uint32_t, including their nullptr guards, plus several smoke calls:
//  - GetListTensor / GetTensor on nullptr;
//  - CloneOpDesc;
//  - the OpDescUtils const/non-const input queries, on an input-less node and
//    on a null ConstNodePtr.
TEST_F(ge_test_opdesc_utils, AttrUtils)
{
    OpDescPtr opDescPtr = std::make_shared<OpDesc>("Multinomial", "Multinomial");

    // int64_t list round-trip.
    std::initializer_list<int64_t> value {2, 3, 4};
    vector<int64_t> value_tmp;
    EXPECT_TRUE(AttrUtils::SetListInt(opDescPtr, "min_size_num_list", value));
    EXPECT_FALSE(AttrUtils::SetListInt(nullptr, "min_size_num_list", value));
    EXPECT_TRUE(AttrUtils::GetListInt(opDescPtr, "min_size_num_list", value_tmp));
    EXPECT_FALSE(AttrUtils::GetListInt(nullptr, "min_size_num_list", value_tmp));
    // Fatal size check: indexing an empty vector below would be undefined behaviour.
    ASSERT_EQ(value_tmp.size(), value.size());
    EXPECT_EQ(value_tmp[0], 2);

    // int32_t list round-trip (overwrites the same attribute name).
    std::initializer_list<int32_t> value1 {2, 3, 4};
    vector<int32_t> value_tmp1;
    EXPECT_TRUE(AttrUtils::SetListInt(opDescPtr, "min_size_num_list", value1));
    EXPECT_FALSE(AttrUtils::SetListInt(nullptr, "min_size_num_list", value1));
    EXPECT_TRUE(AttrUtils::GetListInt(opDescPtr, "min_size_num_list", value_tmp1));
    EXPECT_FALSE(AttrUtils::GetListInt(nullptr, "min_size_num_list", value_tmp1));
    ASSERT_EQ(value_tmp1.size(), value1.size());
    EXPECT_EQ(value_tmp1[0], 2);

    // uint32_t list round-trip.
    std::initializer_list<uint32_t> value2 {2, 3, 4};
    vector<uint32_t> value_tmp2;
    EXPECT_TRUE(AttrUtils::SetListInt(opDescPtr, "min_size_num_list", value2));
    EXPECT_FALSE(AttrUtils::SetListInt(nullptr, "min_size_num_list", value2));

    // GetListTensor rejects a null OpDesc regardless of the attribute name.
    vector<TensorPtr> value_tmp4;
    EXPECT_FALSE(AttrUtils::GetListTensor(nullptr, "min_size_num_list", value_tmp4));
    EXPECT_FALSE(AttrUtils::GetListTensor(nullptr, "Multinomial", value_tmp4));

    EXPECT_TRUE(AttrUtils::GetListInt(opDescPtr, "min_size_num_list", value_tmp2));
    EXPECT_FALSE(AttrUtils::GetListInt(nullptr, "min_size_num_list", value_tmp2));
    ASSERT_EQ(value_tmp2.size(), value2.size());
    EXPECT_EQ(value_tmp2[0], 2U);

    // Scalar int read back as int64_t / int32_t / uint32_t. Outputs are
    // zero-initialised so a failed GetInt cannot leave an indeterminate value.
    EXPECT_TRUE(ge::AttrUtils::SetInt(opDescPtr, "min_size_num", 2));
    EXPECT_FALSE(AttrUtils::SetInt(nullptr, "min_size_num", 2));
    int64_t value3 = 0;
    EXPECT_TRUE(ge::AttrUtils::GetInt(opDescPtr, "min_size_num", value3));
    EXPECT_FALSE(AttrUtils::GetInt(nullptr, "min_size_num", value3));
    int32_t value4 = 0;
    EXPECT_TRUE(ge::AttrUtils::GetInt(opDescPtr, "min_size_num", value4));
    EXPECT_FALSE(AttrUtils::GetInt(nullptr, "min_size_num", value4));
    uint32_t value5 = 0;
    EXPECT_TRUE(ge::AttrUtils::GetInt(opDescPtr, "min_size_num", value5));
    EXPECT_FALSE(AttrUtils::GetInt(nullptr, "min_size_num", value5));
    EXPECT_EQ(value3, 2);
    EXPECT_EQ(value4, 2);
    EXPECT_EQ(value5, 2U);

    // Smoke calls: results are not asserted, only that they do not crash.
    (void)AttrUtils::CloneOpDesc(opDescPtr);
    (void)AttrUtils::CloneOpDesc(nullptr);

    TensorPtr value_tensor;
    (void)AttrUtils::GetTensor(nullptr, "Multinomial", value_tensor);
    (void)OpDescUtils::GetWeights(nullptr);
    (void)OpDescUtils::GetNonConstInputsSize(nullptr);
    (void)OpDescUtils::GetNonConstOutputsSize(nullptr);

    // Non-const input queries on a node with no inputs and on a null ConstNodePtr.
    std::shared_ptr<ComputeGraph> graph_ptr = ge::ComputeGraph::Make("ut_test_graph");
    OpDescPtr opdesc_ptr = std::make_shared<OpDesc>("test", "test");
    Node* node_ptr_return = graph_ptr->ROLE(GraphModifier).AddNode(opdesc_ptr);
    ASSERT_NE(nullptr, node_ptr_return);
    (void)OpDescUtils::GetNonConstInputTensorDesc(*node_ptr_return, 3);
    ConstNodePtr constnodeptr = nullptr;
    (void)OpDescUtils::GetNonConstInputTensorDesc(constnodeptr, 4);
    size_t indexNonConst = 3;
    size_t index = 0;
    EXPECT_FALSE(OpDescUtils::GetNonConstInputIndex(*node_ptr_return, indexNonConst, index));
    EXPECT_FALSE(OpDescUtils::IsNonConstInput(*node_ptr_return, 1));
    EXPECT_FALSE(OpDescUtils::GetNonConstInputIndex(constnodeptr, indexNonConst, index));
    EXPECT_FALSE(OpDescUtils::IsNonConstInput(constnodeptr, 1));
    (void)OpDescUtils::GetConstInputs(constnodeptr);
    (void)OpDescUtils::GetNonConstTensorDesc(constnodeptr);
    (void)OpDescUtils::GetNonConstOutputsSize(*node_ptr_return);
    (void)OpDescUtils::GetNonConstTensorDesc(*node_ptr_return);
}

// The test creates a ConstHolderOpDesc of type "Const" and gives it one weight via
// SetWeights. It then expects MutableWeights on the corresponding node to return
// exactly one weight.
TEST_F(ge_test_opdesc_utils, ConstHolder_MutableWeights_Test)
{
    OpDescPtr descPtr = std::make_shared<ConstHolderOpDesc>("name1", "Const");
    EXPECT_EQ(descPtr->AddOutputDesc("y", TensorDesc(Shape({1, 32, 8, 8}), FORMAT_NCHW)), GRAPH_SUCCESS);

    ComputeGraphPtr graphPtr = ge::ComputeGraph::Make("name");
    // n1 is dereferenced below, so stop if AddNode failed.
    Node* n1 = graphPtr->ROLE(GraphModifier).AddNode(descPtr);
    ASSERT_NE(nullptr, n1);

    float f[1] = {1.0};
    TensorDesc tensorDesc(Shape({1}));
    TensorPtr tensor = std::make_shared<Tensor>(tensorDesc, reinterpret_cast<const uint8_t*>(f), 1 * sizeof(float));

    EXPECT_EQ(OpDescUtils::SetWeights(descPtr, tensor), GRAPH_SUCCESS);
    vector<TensorPtr> weights = OpDescUtils::MutableWeights(*n1);
    // Unsigned literal avoids a signed/unsigned comparison inside EXPECT_EQ.
    EXPECT_EQ(weights.size(), 1U);
}